import torch 
import tqdm
import numpy as np
from torch.utils.data import DataLoader


def get_metric_function(metric_name):
    """Return the metric function registered under ``metric_name``.

    Supported names are ``"logistic_accuracy"``, ``"softmax_accuracy"``,
    ``"softmax_loss"`` and ``"logistic_loss"``; each maps to the
    module-level function of the same name.

    Raises:
        ValueError: if ``metric_name`` is not one of the supported names.
            Failing here keeps the error next to its cause instead of
            handing ``None`` back to a caller that will try to call it.
    """
    if metric_name == "logistic_accuracy":
        return logistic_accuracy

    if metric_name == "softmax_accuracy":
        return softmax_accuracy

    if metric_name == "softmax_loss":
        return softmax_loss

    if metric_name == "logistic_loss":
        return logistic_loss

    raise ValueError("Metric name not found: %r" % (metric_name,))


@torch.no_grad()
def compute_metric_on_dataset(model, dataset, metric_name):
    """Return the sample-weighted average of a metric over ``dataset``.

    The metric is looked up by name via ``get_metric_function`` and
    evaluated batch by batch (batch size 1024, no shuffling, last partial
    batch kept) with gradients disabled.  The model is switched to eval
    mode and is NOT switched back afterwards.

    Args:
        model: module called by the metric function on each batch.
        dataset: dataset whose items unpack into ``(images, labels)``.
        metric_name: name understood by ``get_metric_function``.

    Returns:
        The metric averaged over all samples, as a Python ``float``.
    """
    metric_function = get_metric_function(metric_name)

    model.eval()

    loader = DataLoader(dataset, drop_last=False, batch_size=1024)
    print("> Computing %s..." % (metric_name))

    weighted_total = 0.
    for images, labels in tqdm.tqdm(loader):
        batch_score = metric_function(model, images, labels).item()
        # Each batch score is a mean, so weight it by the batch size to
        # keep a smaller final batch from being over-counted.
        weighted_total += batch_score * images.shape[0]

    n_items = len(loader.dataset)
    return float(weighted_total / n_items)


@torch.no_grad()
def estimate_eps_f(model, dataset, metric_name, n_samples, sample_size, factor):
    """Return ``factor`` times the spread of a metric across random batches.

    Evaluates the named metric on shuffled mini-batches of ``sample_size``
    items, keeping at most ``n_samples`` batch values, and returns their
    standard deviation (``np.std`` defaults) multiplied by ``factor``.
    Runs with gradients disabled and in eval mode; the model is always
    put into training mode before returning.

    Args:
        model: module called by the metric function on each batch.
        dataset: dataset whose items unpack into ``(images, labels)``.
        metric_name: name understood by ``get_metric_function``.
        n_samples: maximum number of batches to evaluate.
        sample_size: batch size of each evaluated batch.
        factor: multiplier applied to the standard deviation.
    """
    metric_function = get_metric_function(metric_name)

    model.eval()

    loader = DataLoader(dataset, batch_size=sample_size, shuffle=True, drop_last=False)
    print("> Computing %s..." % (metric_name))

    batch_scores = []
    for batch_idx, (images, labels) in enumerate(tqdm.tqdm(loader)):
        # Stop as soon as the requested number of batches was collected.
        if batch_idx >= n_samples:
            break

        # Batches are passed to the model exactly as the loader yields
        # them; no device transfer happens here.
        score = metric_function(model, images, labels)
        batch_scores.append(score.item())

    model.train()

    return np.std(batch_scores) * factor


def softmax_loss(model, images, labels, backwards=False):
    """Return the mean cross-entropy loss of ``model`` on one batch.

    Args:
        model: callable producing class logits from ``images``.
        images: batch passed to ``model`` unchanged.
        labels: class-index targets; flattened to 1-D before use.
        backwards: when true, and the loss is attached to an autograd
            graph, ``loss.backward()`` is called before returning.

    Returns:
        The scalar loss tensor.
    """
    scores = model(images)
    targets = labels.view(-1)
    loss = torch.nn.functional.cross_entropy(scores, targets, reduction="mean")

    if not backwards:
        return loss

    # Under ``torch.no_grad()`` the loss has no graph; skip backward then.
    if loss.requires_grad:
        loss.backward()

    return loss


def logistic_loss(model, images, labels, backwards=False):
    """Return the mean binary cross-entropy (with logits) on one batch.

    Args:
        model: callable producing one logit per sample from ``images``.
        images: batch passed to ``model`` unchanged.
        labels: binary targets; flattened to 1-D.  Integer or bool labels
            are cast to the logits' dtype, floating-point labels are used
            as given.
        backwards: when true, and the loss is attached to an autograd
            graph, ``loss.backward()`` is called before returning.

    Returns:
        The scalar loss tensor.
    """
    logits = model(images).reshape(-1)
    targets = labels.reshape(-1)

    # BCEWithLogitsLoss needs floating-point targets; integer class labels
    # (accepted by logistic_accuracy) would otherwise raise here.
    if not targets.is_floating_point():
        targets = targets.to(logits.dtype)

    criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
    loss = criterion(logits, targets)

    # Under ``torch.no_grad()`` the loss has no graph; skip backward then.
    if backwards and loss.requires_grad:
        loss.backward()

    return loss


def logistic_accuracy(model, images, labels):
    """Return the fraction of samples whose thresholded prediction is correct.

    The model output is passed through a sigmoid, flattened, and
    thresholded at 0.5 to obtain 0/1 predictions, which are compared with
    the flattened ``labels``.

    Args:
        model: callable producing one logit per sample from ``images``.
        images: batch passed to ``model`` unchanged.
        labels: binary targets; any shape with one entry per sample.

    Returns:
        A scalar tensor holding the mean accuracy over the batch.
    """
    probs = torch.sigmoid(model(images)).view(-1)
    pred_labels = (probs > 0.5).float()

    # Flatten the labels too: comparing an (N,) prediction vector with an
    # (N, 1) label column would broadcast to an (N, N) matrix and report
    # a meaningless accuracy.
    acc = (pred_labels == labels.reshape(-1)).float().mean()

    return acc


def softmax_accuracy(model, images, labels):
    """Return the fraction of samples whose arg-max class matches its label.

    Args:
        model: callable producing class scores from ``images``; the class
            axis is dimension 1.
        images: batch passed to ``model`` unchanged.
        labels: class-index targets with one entry per prediction.

    Returns:
        A scalar tensor holding the mean accuracy over the batch.
    """
    logits = model(images)
    pred_labels = logits.argmax(dim=1)

    # Compare flattened tensors: an (N, 1) label column against (N,)
    # predictions would otherwise broadcast to an (N, N) matrix and report
    # a meaningless accuracy.
    matches = pred_labels.reshape(-1) == labels.reshape(-1)
    acc = matches.float().mean()

    return acc